{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Causal Discovery example\n",
    "\n",
    "The goal of this notebook is to show how causal discovery methods can work with DoWhy. We use discovery methods from the [Causal Discovery Toolbox (CDT)](https://github.com/FenTechSolutions/CausalDiscoveryToolbox) repository. As we will see, causal discovery methods are not foolproof, and there is no guarantee that they will recover the correct causal graph. Even for the simple examples below, the results vary widely across methods. These methods can, however, be usefully combined with domain knowledge to construct the final causal graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dowhy\n",
    "from dowhy import CausalModel\n",
    "\n",
    "from rpy2.robjects import r as R\n",
    "%load_ext rpy2.ipython\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import graphviz\n",
    "import networkx as nx \n",
    "\n",
    "np.set_printoptions(precision=3, suppress=True)\n",
    "np.random.seed(0)"
   ]
  },
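  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utility functions\n",
    "We define two utility functions: `make_graph` draws a directed graph from an adjacency matrix, and `str_to_dot` converts the resulting graphviz source into a DOT string that can be passed to DoWhy."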
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_graph(adjacency_matrix, labels=None):\n",
    "    idx = np.abs(adjacency_matrix) > 0.01\n",
    "    dirs = np.where(idx)\n",
    "    d = graphviz.Digraph(engine='dot')\n",
    "    names = labels if labels else [f'x{i}' for i in range(len(adjacency_matrix))]\n",
    "    for name in names:\n",
    "        d.node(name)\n",
    "    for to, from_, coef in zip(dirs[0], dirs[1], adjacency_matrix[idx]):\n",
    "        d.edge(names[from_], names[to], label=str(coef))\n",
    "    return d\n",
    "\n",
    "def str_to_dot(string):\n",
    "    '''\n",
    "    Converts input string from graphviz library to valid DOT graph format.\n",
    "    '''\n",
    "    graph = string.strip().replace('\\n', ';').replace('\\t','')\n",
    "    graph = graph[:9] + graph[10:-2] + graph[-1] # Removing unnecessary characters from string\n",
    "    return graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiments on the Auto-MPG dataset\n",
    "\n",
    "In this section, we use a dataset of technical specifications of cars, downloaded from the UCI Machine Learning Repository. The dataset contains 9 attributes and 398 instances. We do not know the true causal graph for this dataset and will use CDT to discover it. The discovered causal graph is then used to estimate the causal effect.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_mpg = pd.read_csv('http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data-original',\n",
    "                   sep=r'\\s+', header=None,  # whitespace-delimited file\n",
    "                   names = ['mpg', 'cylinders', 'displacement',\n",
    "                            'horsepower', 'weight', 'acceleration',\n",
    "                            'model year', 'origin', 'car name'])\n",
    "data_mpg.dropna(inplace=True)\n",
    "data_mpg.drop(['model year', 'origin', 'car name'], axis=1, inplace=True)\n",
    "print(data_mpg.shape)\n",
    "data_mpg.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Causal Discovery with Causal Discovery Toolbox (CDT)\n",
    "\n",
    "We use the CDT library to perform causal discovery on the Auto-MPG dataset. We use three causal discovery methods here: LiNGAM, PC, and GES. These methods are widely used and quick to run, which makes them well suited to an introduction to the topic. Neural-network-based methods are also available in CDT, and users are encouraged to try them out.\n",
    "\n",
    "The documentation for the methods used is as follows:\n",
    "- LiNGAM [[link]](https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/_modules/cdt/causality/graph/LiNGAM.html)\n",
    "- PC [[link]](https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/_modules/cdt/causality/graph/PC.html)\n",
    "- GES [[link]](https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/_modules/cdt/causality/graph/GES.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cdt.causality.graph import LiNGAM, PC, GES\n",
    "\n",
    "graphs = {}\n",
    "labels = [str(col) for col in data_mpg.columns]\n",
    "functions = {\n",
    "    'LiNGAM' : LiNGAM,\n",
    "    'PC' : PC,\n",
    "    'GES' : GES,\n",
    "}\n",
    "\n",
    "for method, lib in functions.items():\n",
    "    obj = lib()\n",
    "    output = obj.predict(data_mpg)\n",
    "    # nx.to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array returns an ndarray directly\n",
    "    adj_matrix = nx.to_numpy_array(output)\n",
    "    graph_dot = make_graph(adj_matrix, labels)\n",
    "    graphs[method] = graph_dot\n",
    "\n",
    "# Visualize graphs\n",
    "for method, graph in graphs.items():\n",
    "    print(\"Method : %s\"%(method))\n",
    "    display(graph)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, no two methods agree on the graphs. PC and GES effectively produce an undirected graph whereas LiNGAM produces a directed graph. We use only the LiNGAM method in the next section."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Estimate causal effects using Linear Regression\n",
    "\n",
    "Now let us see whether these differences in the graphs also lead to significant differences in the causal estimate of the effect of *mpg* on *weight*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for method, graph in graphs.items():\n",
    "        if method != \"LiNGAM\":\n",
    "            continue\n",
    "        print('\\n*****************************************************************************\\n')\n",
    "        print(\"Causal Discovery Method : %s\"%(method))\n",
    "        \n",
    "        # Obtain valid dot format\n",
    "        graph_dot = str_to_dot(graph.source)\n",
    "\n",
    "        # Define Causal Model\n",
    "        model=CausalModel(\n",
    "                data = data_mpg,\n",
    "                treatment='mpg',\n",
    "                outcome='weight',\n",
    "                graph=graph_dot)\n",
    "\n",
    "        # Identification\n",
    "        identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)\n",
    "        print(identified_estimand)\n",
    "        \n",
    "        # Estimation\n",
    "        estimate = model.estimate_effect(identified_estimand,\n",
    "                                        method_name=\"backdoor.linear_regression\",\n",
    "                                        control_value=0,\n",
    "                                        treatment_value=1,\n",
    "                                        confidence_intervals=True,\n",
    "                                        test_significance=True)\n",
    "        print(\"Causal Estimate is \" + str(estimate.value))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As mentioned earlier, because PC and GES do not produce directed edges, no backdoor, instrumental, or frontdoor variables can be found for their graphs, so causal effect estimation is not possible for these methods. LiNGAM, however, does discover a DAG, and hence it is possible to output a causal estimate for LiNGAM. The estimate is still quite far from the original estimate of -70.466 (which can be calculated from the graph)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiments on the Sachs dataset\n",
    "\n",
    "The dataset consists of the simultaneous measurements of 11 phosphorylated proteins and phospholipids derived from thousands of individual primary immune system cells, subjected to both general and specific molecular interventions (Sachs et al., 2005).\n",
    "\n",
    "The specifications of the dataset are as follows:\n",
    "- Number of nodes: 11\n",
    "- Number of arcs: 17\n",
    "- Number of parameters: 178\n",
    "- Average Markov blanket size: 3.09\n",
    "- Average degree: 3.09\n",
    "- Maximum in-degree: 3\n",
    "- Number of instances: 7466\n",
    "\n",
    "The original causal graph is known for the Sachs dataset and we compare the original graph with the ones discovered using CDT in this section."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cdt.data import load_dataset\n",
    "data_sachs, graph_sachs = load_dataset(\"sachs\")\n",
    "\n",
    "data_sachs.dropna(inplace=True)\n",
    "print(data_sachs.shape)\n",
    "data_sachs.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ground truth of the causal graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = [str(col) for col in data_sachs.columns]\n",
    "# nx.to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array returns an ndarray directly\n",
    "adj_matrix = nx.to_numpy_array(graph_sachs)\n",
    "graph_dot = make_graph(adj_matrix, labels)\n",
    "display(graph_dot)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Causal Discovery with Causal Discovery Toolbox (CDT)\n",
    "\n",
    "We use the CDT library to perform causal discovery on the Sachs dataset. We use three causal discovery methods here: LiNGAM, PC, and GES. These methods are widely used and quick to run, which makes them well suited to an introduction to the topic. Neural-network-based methods are also available in CDT, and users are encouraged to try them out.\n",
    "\n",
    "The documentation for the methods used is as follows:\n",
    "- LiNGAM [[link]](https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/_modules/cdt/causality/graph/LiNGAM.html)\n",
    "- PC [[link]](https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/_modules/cdt/causality/graph/PC.html)\n",
    "- GES [[link]](https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/_modules/cdt/causality/graph/GES.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cdt.causality.graph import LiNGAM, PC, GES\n",
    "\n",
    "graphs = {}\n",
    "graphs_nx = {}\n",
    "labels = [str(col) for col in data_sachs.columns]\n",
    "functions = {\n",
    "    'LiNGAM' : LiNGAM,\n",
    "    'PC' : PC,\n",
    "    'GES' : GES,\n",
    "}\n",
    "\n",
    "for method, lib in functions.items():\n",
    "    obj = lib()\n",
    "    output = obj.predict(data_sachs)\n",
    "    graphs_nx[method] = output\n",
    "    # nx.to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array returns an ndarray directly\n",
    "    adj_matrix = nx.to_numpy_array(output)\n",
    "    graph_dot = make_graph(adj_matrix, labels)\n",
    "    graphs[method] = graph_dot\n",
    "\n",
    "# Visualize graphs\n",
    "for method, graph in graphs.items():\n",
    "    print(\"Method : %s\"%(method))\n",
    "    display(graph)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, no two methods agree on the graphs. Next, we study the causal effect estimates obtained from these graphs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Estimate effects using Linear Regression\n",
    "\n",
    "Now let us see whether these differences in the graphs also lead to significant differences in the causal estimate of the effect of *PIP2* on *PKC*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for method, graph in graphs.items():\n",
    "        if method != \"LiNGAM\":\n",
    "            continue\n",
    "        print('\\n*****************************************************************************\\n')\n",
    "        print(\"Causal Discovery Method : %s\"%(method))\n",
    "\n",
    "        # Obtain valid dot format\n",
    "        graph_dot = str_to_dot(graph.source)\n",
    "\n",
    "        # Define Causal Model\n",
    "        model=CausalModel(\n",
    "                data = data_sachs,\n",
    "                treatment='PIP2',\n",
    "                outcome='PKC',\n",
    "                graph=graph_dot)\n",
    "\n",
    "        # Identification\n",
    "        identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)\n",
    "        print(identified_estimand)\n",
    "\n",
    "        # Estimation\n",
    "        estimate = model.estimate_effect(identified_estimand,\n",
    "                                        method_name=\"backdoor.linear_regression\",\n",
    "                                        control_value=0,\n",
    "                                        treatment_value=1,\n",
    "                                        confidence_intervals=True,\n",
    "                                        test_significance=True)\n",
    "        print(\"Causal Estimate is \" + str(estimate.value))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From the causal estimates obtained, it can be seen that the three graphs differ in several respects. The graph obtained using LiNGAM contains a backdoor path and instrumental variables. On the other hand, the graph obtained using PC contains a backdoor path and a frontdoor path. However, despite these differences, both obtain the same mean causal estimate.\n",
    "\n",
    "The graph obtained using GES contains only a backdoor path with different backdoor variables, and it yields a different causal estimate from the first two cases."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Graph Validation\n",
    "\n",
    "We compare the graphs obtained from the causal discovery methods with the true causal graph using two graph distance metrics: Structural Hamming Distance (SHD) and Structural Intervention Distance (SID). SHD between two graphs is, in simple terms, the number of edge insertions, deletions, or flips needed to transform one graph into the other. SID, on the other hand, is based on a graphical criterion only and quantifies the closeness between two DAGs in terms of their corresponding causal inference statements."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cdt.metrics import SHD, SHD_CPDAG, SID, SID_CPDAG\n",
    "\n",
    "for method, graph in graphs_nx.items():\n",
    "    print(\"***********************************************************\")\n",
    "    print(\"Method: %s\"%(method))\n",
    "    tar, pred = graph_sachs, graph\n",
    "    print(\"SHD_CPDAG = %f\"%(SHD_CPDAG(tar, pred)))\n",
    "    print(\"SHD = %f\"%(SHD(tar, pred, double_for_anticausal=False)))\n",
    "    print(\"SID_CPDAG = [%f, %f]\"%(SID_CPDAG(tar, pred)))\n",
    "    print(\"SID = %f\"%(SID(tar, pred)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The graph distance metrics show that the scores are lowest for the LiNGAM method. Since a lower distance means a closer match, LiNGAM provides, of the three methods used, the graph that is most similar to the original graph."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Graph Refutation\n",
    "\n",
    "Here, we use the same SHD and SID metrics to find out how different the discovered graphs are from each other."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "from cdt.metrics import SHD, SHD_CPDAG, SID, SID_CPDAG\n",
    "\n",
    "# Find combinations of pair of methods to compare\n",
    "combinations = list(itertools.combinations(graphs_nx, 2))\n",
    "\n",
    "for pair in combinations:\n",
    "    print(\"***********************************************************\")\n",
    "    graph1 = graphs_nx[pair[0]]\n",
    "    graph2 = graphs_nx[pair[1]]\n",
    "    print(\"Methods: %s and %s\"%(pair[0], pair[1]))\n",
    "    print(\"SHD_CPDAG = %f\"%(SHD_CPDAG(graph1, graph2)))\n",
    "    print(\"SHD = %f\"%(SHD(graph1, graph2, double_for_anticausal=False)))\n",
    "    print(\"SID_CPDAG = [%f, %f]\"%(SID_CPDAG(graph1, graph2)))\n",
    "    print(\"SID = %f\"%(SID(graph1, graph2)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The values of the metrics show how different the graphs are from each other. A higher distance value implies a larger difference between the graphs."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "metadata": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
